# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2020, Numenta, Inc.  Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program.  If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------

import logging

from nupic.research.frameworks.pytorch.model_utils import set_random_seed
from nupic.research.frameworks.vernon import interfaces

# Names exported by ``from <this module> import *``; only the base class is public.
__all__ = [
    "ExperimentBase",
]


class ExperimentBase(
    interfaces.Experiment,  # Implements
):
    """
    Minimal base implementation of the vernon ``Experiment`` interface.

    Provides logger creation, random seeding and no-op defaults for the
    optional hooks. ``run_iteration`` and ``should_stop`` are not implemented
    here and must be provided by a subclass or mixin (see
    ``get_execution_order``).
    """

    # Marker attribute set on console handlers installed by ``create_logger``
    # so a later call can recognise and replace them instead of stacking more.
    _CONSOLE_HANDLER_FLAG = "_experiment_base_console_handler"

    @staticmethod
    def experiment_interface_implemented():
        """Return True to signal that this class implements the Experiment
        interface."""
        return True

    def setup_experiment(self, config):
        """
        Configure the logger, the log directory and the random seed.

        :param config: Dictionary containing the configuration parameters

            - local_dir: Results path
            - logdir: Directory generated by Ray Tune for this Trial
            - seed: the seed to be used for pytorch, python, and numpy
            - checkpoint_at_init: boolean argument for whether to create a checkpoint
                                  of the initialized model. this differs from
                                  `checkpoint_at_start` for which the checkpoint occurs
                                  after the first epoch of training as opposed to
                                  before it
            - name / log_format / log_level: forwarded to :meth:`create_logger`

        This base implementation only reads ``logdir`` (default None), ``seed``
        (default 42) and the logging keys. ``local_dir`` and
        ``checkpoint_at_init`` are not read here; presumably subclasses or
        mixins consume them.

        Sets ``self._logger`` (exposed as ``self.logger``), ``self.logdir`` and
        ``self.seed``.
        """
        self._logger = self.create_logger(config)
        self.logdir = config.get("logdir", None)

        # Configure seed
        self.seed = config.get("seed", 42)
        # NOTE(review): the second positional argument is passed as False;
        # its meaning is defined by model_utils.set_random_seed -- confirm there.
        set_random_seed(self.seed, False)

    @property
    def logger(self):
        """The logger created by :meth:`setup_experiment`."""
        return self._logger

    @classmethod
    def create_logger(cls, config):
        """
        Return the experiment logger with exactly one console handler from
        this method attached.

        :param config: Dictionary containing the configuration parameters

            - name: logger name; defaults to the class name
            - log_format: console format string; defaults to
                          ``logging.BASIC_FORMAT``
            - log_level: level name (case-insensitive, e.g. "debug") or an
                         integer level (e.g. ``logging.DEBUG``); defaults to
                         "INFO"

        :return: the configured :class:`logging.Logger`
        :raises ValueError: if ``log_level`` is not a known level name
        """
        log_format = config.get("log_format", logging.BASIC_FORMAT)
        log_level = config.get("log_level", "INFO")
        if isinstance(log_level, str):
            # The built-in level names understood by Logger.setLevel are
            # upper case ("INFO", "DEBUG", ...); integers pass through as-is.
            log_level = log_level.upper()

        logger = logging.getLogger(config.get("name", cls.__name__))
        # Set the level first so an invalid level fails before any handler
        # is modified.
        logger.setLevel(log_level)

        # logging.getLogger returns the same Logger object for the same name,
        # so repeated calls (e.g. several experiments in one process) would
        # otherwise stack console handlers and emit every record several
        # times. Drop any handler installed by a previous call first.
        for handler in list(logger.handlers):
            if getattr(handler, cls._CONSOLE_HANDLER_FLAG, False):
                logger.removeHandler(handler)
                handler.close()

        console = logging.StreamHandler()
        console.setFormatter(logging.Formatter(log_format))
        setattr(console, cls._CONSOLE_HANDLER_FLAG, True)
        logger.addHandler(console)
        return logger

    def stop_experiment(self):
        """Stop hook; no-op in this base class."""
        pass

    def run_pre_experiment(self):
        """Pre-experiment hook; no-op in this base class."""
        pass

    @classmethod
    def insert_pre_experiment_result(cls, result, pre_experiment_result):
        """Hook for inserting ``pre_experiment_result`` into ``result``;
        no-op in this base class (``result`` is left untouched)."""
        pass

    def get_state(self):
        """Return this class's contribution to the experiment state: nothing
        (an empty dict)."""
        return {}

    def set_state(self, state):
        """Restore state produced by :meth:`get_state`; no-op here."""
        pass

    @classmethod
    def get_readable_result(cls, result):
        """Return ``result`` unchanged (no filtering in the base class)."""
        return result

    @classmethod
    def get_execution_order(cls):
        """
        Describe what this class contributes to each interface method.

        :return: dict mapping method name to a list of human-readable
                 descriptions (an empty list means "nothing done here")
        """
        exp = "ExperimentBase"
        return dict(
            setup_experiment=[exp + ": Initialize logger"],
            run_iteration=[exp + ": Not implemented, must override"],
            run_pre_experiment=[exp + ": No pre_experiment implemented"],
            get_readable_result=[exp + ": Return unfiltered result"],
            insert_pre_experiment_result=[],
            stop_experiment=[],
            should_stop=[exp + ": Not implemented, must override"],
            get_state=[],
            set_state=[],
        )
